{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b05af3b",
   "metadata": {},
   "source": [
    "(tune-huggingface-example)=\n",
    "\n",
    "# Using |:hugging_face:| Hugging Face Transformers with Tune\n",
    "\n",
    "```{image} /images/hugging.png\n",
    ":align: center\n",
    ":alt: Huggingface Logo\n",
    ":height: 120px\n",
    ":target: https://huggingface.co\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e3c389",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "This example uses the official\n",
    "huggingface transformers `hyperparameter_search` API.\n",
    "\"\"\"\n",
    "import os\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.tune import CLIReporter\n",
    "from ray.tune.examples.pbt_transformers.utils import (\n",
    "    download_data,\n",
    "    build_compute_metrics_fn,\n",
    ")\n",
    "from ray.tune.schedulers import PopulationBasedTraining\n",
    "from transformers import (\n",
    "    glue_tasks_num_labels,\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    GlueDataset,\n",
    "    GlueDataTrainingArguments,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "\n",
    "def tune_transformer(num_samples=8, gpus_per_trial=0, smoke_test=False):\n",
    "    \"\"\"Fine-tune a GLUE (RTE) model and tune it with Population Based Training.\n",
    "\n",
    "    Drives the Hugging Face ``Trainer.hyperparameter_search`` API with the\n",
    "    Ray Tune backend (``backend=\"ray\"``).\n",
    "\n",
    "    Args:\n",
    "        num_samples: Number of Tune trials (``n_trials``) to launch.\n",
    "        gpus_per_trial: GPUs reserved for each trial. A value <= 0 also\n",
    "            sets ``no_cuda`` so training runs on CPU.\n",
    "        smoke_test: If True, use a tiny model and a separate ``./test_data``\n",
    "            directory, and limit each trial to one step / training iteration\n",
    "            so the example finishes quickly.\n",
    "    \"\"\"\n",
    "    data_dir_name = \"./data\" if not smoke_test else \"./test_data\"\n",
    "    data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))\n",
    "    # makedirs(exist_ok=True) is safe to re-run and does not fail if the\n",
    "    # directory appears between an existence check and its creation.\n",
    "    os.makedirs(data_dir, 0o755, exist_ok=True)\n",
    "\n",
    "    # Change these as needed.\n",
    "    model_name = (\n",
    "        \"bert-base-uncased\" if not smoke_test else \"sshleifer/tiny-distilroberta-base\"\n",
    "    )\n",
    "    task_name = \"rte\"\n",
    "\n",
    "    task_data_dir = os.path.join(data_dir, task_name.upper())\n",
    "\n",
    "    num_labels = glue_tasks_num_labels[task_name]\n",
    "\n",
    "    config = AutoConfig.from_pretrained(\n",
    "        model_name, num_labels=num_labels, finetuning_task=task_name\n",
    "    )\n",
    "\n",
    "    # Download and cache tokenizer, model, and features\n",
    "    print(\"Downloading and caching Tokenizer\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "    # Triggers model download to cache (the returned model is discarded).\n",
    "    print(\"Downloading and caching pre-trained model\")\n",
    "    AutoModelForSequenceClassification.from_pretrained(\n",
    "        model_name,\n",
    "        config=config,\n",
    "    )\n",
    "\n",
    "    # Passed to the Trainer as ``model_init`` so each trial can build a\n",
    "    # fresh model instead of sharing one instance.\n",
    "    def get_model():\n",
    "        return AutoModelForSequenceClassification.from_pretrained(\n",
    "            model_name,\n",
    "            config=config,\n",
    "        )\n",
    "\n",
    "    # Download data.\n",
    "    download_data(task_name, data_dir)\n",
    "\n",
    "    data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=task_data_dir)\n",
    "\n",
    "    train_dataset = GlueDataset(\n",
    "        data_args, tokenizer=tokenizer, mode=\"train\", cache_dir=task_data_dir\n",
    "    )\n",
    "    eval_dataset = GlueDataset(\n",
    "        data_args, tokenizer=tokenizer, mode=\"dev\", cache_dir=task_data_dir\n",
    "    )\n",
    "\n",
    "    # Base arguments. Entries tagged \"# config\" share their names with keys in\n",
    "    # tune_config / hyperparam_mutations below, which set them per trial.\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=\".\",\n",
    "        learning_rate=1e-5,  # config\n",
    "        do_train=True,\n",
    "        do_eval=True,\n",
    "        no_cuda=gpus_per_trial <= 0,\n",
    "        evaluation_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        load_best_model_at_end=True,\n",
    "        num_train_epochs=2,  # config\n",
    "        max_steps=-1,\n",
    "        per_device_train_batch_size=16,  # config\n",
    "        per_device_eval_batch_size=16,  # config\n",
    "        warmup_steps=0,\n",
    "        weight_decay=0.1,  # config\n",
    "        logging_dir=\"./logs\",\n",
    "        skip_memory_metrics=True,\n",
    "        report_to=\"none\",\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model_init=get_model,\n",
    "        args=training_args,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=eval_dataset,\n",
    "        compute_metrics=build_compute_metrics_fn(task_name),\n",
    "    )\n",
    "\n",
    "    # Search space given to every trial: fixed batch sizes, a sampled epoch\n",
    "    # count, and max_steps=1 in smoke tests so training stops after one step.\n",
    "    tune_config = {\n",
    "        \"per_device_train_batch_size\": 32,\n",
    "        \"per_device_eval_batch_size\": 32,\n",
    "        \"num_train_epochs\": tune.choice([2, 3, 4, 5]),\n",
    "        \"max_steps\": 1 if smoke_test else -1,  # Used for smoke test.\n",
    "    }\n",
    "\n",
    "    # PBT compares trials on eval_acc every training iteration and mutates\n",
    "    # the hyperparameters listed in hyperparam_mutations.\n",
    "    scheduler = PopulationBasedTraining(\n",
    "        time_attr=\"training_iteration\",\n",
    "        metric=\"eval_acc\",\n",
    "        mode=\"max\",\n",
    "        perturbation_interval=1,\n",
    "        hyperparam_mutations={\n",
    "            \"weight_decay\": tune.uniform(0.0, 0.3),\n",
    "            \"learning_rate\": tune.uniform(1e-5, 5e-5),\n",
    "            \"per_device_train_batch_size\": [16, 32, 64],\n",
    "        },\n",
    "    )\n",
    "\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns={\n",
    "            \"weight_decay\": \"w_decay\",\n",
    "            \"learning_rate\": \"lr\",\n",
    "            \"per_device_train_batch_size\": \"train_bs/gpu\",\n",
    "            \"num_train_epochs\": \"num_epochs\",\n",
    "        },\n",
    "        metric_columns=[\"eval_acc\", \"eval_loss\", \"epoch\", \"training_iteration\"],\n",
    "    )\n",
    "\n",
    "    # hp_space ignores its trial argument and always returns tune_config.\n",
    "    trainer.hyperparameter_search(\n",
    "        hp_space=lambda _: tune_config,\n",
    "        backend=\"ray\",\n",
    "        n_trials=num_samples,\n",
    "        resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
    "        scheduler=scheduler,\n",
    "        keep_checkpoints_num=1,\n",
    "        checkpoint_score_attr=\"training_iteration\",\n",
    "        stop={\"training_iteration\": 1} if smoke_test else None,\n",
    "        progress_reporter=reporter,\n",
    "        local_dir=\"~/ray_results/\",\n",
    "        name=\"tune_transformer_pbt\",\n",
    "        log_to_file=True,\n",
    "    )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import argparse\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    # Smoke-test mode is on by default (default=True with store_true), so\n",
    "    # \"--smoke-test\" alone can never turn it off. \"--no-smoke-test\" writes\n",
    "    # to the same destination and is the way to request a full run.\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\", default=True, action=\"store_true\", help=\"Finish quickly for testing\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--no-smoke-test\",\n",
    "        dest=\"smoke_test\",\n",
    "        action=\"store_false\",\n",
    "        help=\"Run the full (non smoke-test) search.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--ray-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        help=\"Address to use for Ray. \"\n",
    "             'Use \"auto\" for cluster. '\n",
    "             \"Defaults to None for local.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of server to connect to if using Ray Client.\",\n",
    "    )\n",
    "\n",
    "    # parse_known_args returns unrecognized arguments instead of erroring.\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    if args.smoke_test:\n",
    "        # Smoke tests ignore --ray-address / --server-address.\n",
    "        ray.init()\n",
    "    elif args.server_address:\n",
    "        ray.init(f\"ray://{args.server_address}\")\n",
    "    else:\n",
    "        ray.init(args.ray_address)\n",
    "\n",
    "    if args.smoke_test:\n",
    "        tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)\n",
    "    else:\n",
    "        # You can change the number of GPUs here:\n",
    "        tune_transformer(num_samples=8, gpus_per_trial=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\"\"\"Utilities to load and cache data.\"\"\"\n",
    "\n",
    "import os\n",
    "from typing import Callable, Dict\n",
    "import numpy as np\n",
    "from transformers import EvalPrediction\n",
    "from transformers import glue_compute_metrics, glue_output_modes\n",
    "\n",
    "\n",
    "def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:\n",
    "    \"\"\"Build a ``compute_metrics`` callback for the Trainer on a GLUE task.\n",
    "\n",
    "    Function from transformers/examples/text-classification/run_glue.py.\n",
    "\n",
    "    Args:\n",
    "        task_name: GLUE task key (e.g. ``\"rte\"``) used to look up the output\n",
    "            mode and to select the metrics in ``glue_compute_metrics``.\n",
    "\n",
    "    Returns:\n",
    "        A function mapping an ``EvalPrediction`` to the metrics dict returned\n",
    "        by ``glue_compute_metrics``.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If ``task_name`` is not in ``glue_output_modes``.\n",
    "        ValueError: If the task's output mode is neither \"classification\"\n",
    "            nor \"regression\".\n",
    "    \"\"\"\n",
    "    output_mode = glue_output_modes[task_name]\n",
    "    # Fail fast here rather than with an unbound ``preds`` at eval time.\n",
    "    if output_mode not in (\"classification\", \"regression\"):\n",
    "        raise ValueError(\n",
    "            \"Unsupported output mode {!r} for task {}\".format(output_mode, task_name)\n",
    "        )\n",
    "\n",
    "    def compute_metrics_fn(p: EvalPrediction):\n",
    "        if output_mode == \"classification\":\n",
    "            # Pick the highest-scoring class along axis 1; presumably\n",
    "            # predictions are (num_examples, num_labels) scores - confirm.\n",
    "            preds = np.argmax(p.predictions, axis=1)\n",
    "        else:  # regression\n",
    "            preds = np.squeeze(p.predictions)\n",
    "        metrics = glue_compute_metrics(task_name, preds, p.label_ids)\n",
    "        return metrics\n",
    "\n",
    "    return compute_metrics_fn\n",
    "\n",
    "\n",
    "def download_data(task_name, data_dir=\"./data\"):\n",
    "    # Download RTE training data\n",
    "    print(\"Downloading dataset.\")\n",
    "    import urllib\n",
    "    import zipfile\n",
    "\n",
    "    if task_name == \"rte\":\n",
    "        url = \"https://dl.fbaipublicfiles.com/glue/data/RTE.zip\"\n",
    "    else:\n",
    "        raise ValueError(\"Unknown task: {}\".format(task_name))\n",
    "    data_file = os.path.join(data_dir, \"{}.zip\".format(task_name))\n",
    "    if not os.path.exists(data_file):\n",
    "        urllib.request.urlretrieve(url, data_file)\n",
    "        with zipfile.ZipFile(data_file) as zip_ref:\n",
    "            zip_ref.extractall(data_dir)\n",
    "        print(\"Downloaded data for task {} to {}\".format(task_name, data_dir))\n",
    "    else:\n",
    "        print(\n",
    "            \"Data already exists. Using downloaded data for task {} from {}\".format(\n",
    "                task_name, data_dir\n",
    "            )\n",
    "        )"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}